{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# JAX and jaxlib should both be at version 0.3.25, because newer JAX\n",
    "# versions are *not* supported on TPU runtimes.\n",
    "# Flax should be included in a fresh kernel.\n",
    "# (`grep -E` is used instead of the obsolescent `egrep`, which newer GNU grep warns about.)\n",
    "!pip freeze | grep -E 'jax|flax'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configure JAX for the Colab TPU via `setup_tpu()`, then list the devices\n",
    "# JAX can see -- this should show 8 TPU devices.\n",
    "import jax\n",
    "import jax.tools.colab_tpu\n",
    "\n",
    "jax.tools.colab_tpu.setup_tpu()\n",
    "jax.devices()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sometimes it's necessary to install additional packages, but JAX/jaxlib\n",
    "# must stay pinned to the versions supported by the TPU runtime.\n",
    "# `flax` is left unpinned, so pip's resolver has to pick a release that is\n",
    "# compatible with the pins above.\n",
    "# `%pip` (unlike `!pip`) installs into the environment of the running kernel.\n",
    "# If this changes a package that was already imported (e.g. `jax` in the TPU\n",
    "# setup cell), restart the runtime so the new version is actually used.\n",
    "%pip install jax==0.3.25 jaxlib==0.3.25 flax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# In case the JAX/jaxlib version has changed after installing packages, the\n",
    "# command below shows the offending packages: `-r` prints the reverse\n",
    "# dependency tree (which installed packages require jax / jaxlib, and with\n",
    "# which version constraints), `-p` restricts output to those two packages,\n",
    "# and `-w silence` suppresses pipdeptree's conflict warnings.\n",
    "!pip install -qq pipdeptree\n",
    "!pipdeptree -w silence -r -p jax,jaxlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The dependency tree can also be resolved without installing the package\n",
    "# itself, but this usually takes around 2-3 minutes.\n",
    "!pip install --quiet --quiet pipgrip\n",
    "!pipgrip --tree 'flax==0.6.4'"
   ]
  }
 ],
 "metadata": {
  "accelerator": "TPU",
  "gpuClass": "standard",
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
